
# -*- coding: utf-8 -*-
from hmac import new
from random import sample
from tkinter import HIDDEN
import buffer
import torch
from torch import nn
import torch.multiprocessing as mp
import yaml
import torch.nn.functional as F
import render
import torch.optim as optim
import numpy as np
import time
import gc
import math
import controller
import agent
import evaluation
from torch.utils.tensorboard import SummaryWriter
# Enable autograd anomaly detection process-wide: a failing backward pass then
# reports the forward operation that produced it. This is a debugging aid and
# slows down every backward pass.
torch.autograd.set_detect_anomaly(True)


# Device name used for tensors created here: 'cuda' when PyTorch can see a
# GPU, otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


class Train:
    """Training driver: builds the policy, helpers and a reward controller, then trains the policy.

    The controller (``self.dominator``) is selected by ``cfg.algorithm``. On
    every iteration it is fitted on freshly collected trajectories, and its
    output is used as the reward the policy update is driven by.
    """

    # cfg.algorithm value -> class name inside the ``controller`` module.
    # Resolved lazily with getattr so only the selected class has to exist.
    _CONTROLLERS = {
        "CB": "CB",
        "cic": "Cic",
        "dads": "Dads",
        "diayn": "Diayn",
        "large_cb": "Large_CB",
    }

    def __init__(self, cfg, env):
        """Create the writer, policy, renderer, buffer, evaluator and controller.

        Args:
            cfg: experiment config. This class reads ``d``, ``algorithm``,
                ``sample_type``, ``sk_num`` and ``save``. ``batch_forward``
                also reads ``state_dim``, ``index_dim`` and ``ep_len``.
            env: environment object; ``batch_forward`` calls ``env.step``.

        An unrecognised ``cfg.algorithm`` prints "controller error" and leaves
        ``self.dominator`` as None; ``train`` then refuses to run.
        """
        self.cfg = cfg
        self.env = env
        self.writer = SummaryWriter('runs/' + str(cfg.d) + cfg.algorithm + str(cfg.sample_type))
        self.policy = agent.Ddpg(cfg)
        self.drawer = render.simulate(cfg, env)
        self.simulater = buffer.Buffer(cfg, env)
        self.evaluate_method = evaluation.evaluate(env, cfg, self.drawer, self.simulater, self.writer)

        controller_name = self._CONTROLLERS.get(cfg.algorithm)
        if controller_name is None:
            # Kept non-fatal for backward compatibility; train() raises instead.
            self.dominator = None
            print("controller error")
        else:
            self.dominator = getattr(controller, controller_name)(cfg)

    def batch_forward(self, cfg, env, trajlist, noise_tensor):
        """Replay a recorded rollout through the policy and ``env.step`` with autograd enabled.

        All entries of ``trajlist`` except the last one along axis 0 are
        flattened into one batch. The batch is passed to the policy, with
        ``noise_tensor`` as the ``offset``, and then to ``env.step``.

        Args:
            cfg: config providing ``state_dim``, ``index_dim``, ``ep_len`` and ``sk_num``.
            env: object whose ``step(states, actions)`` returns the next states.
            trajlist: recorded rollout. It is presumably shaped
                ``(ep_len + 1, sk_num, state_dim + index_dim)``; the reshapes
                below require a compatible element count. TODO confirm.
            noise_tensor: offsets for the policy, flattened to rows of 2 values.

        Returns:
            Tensor of shape ``(cfg.ep_len + 1, cfg.sk_num, cfg.state_dim)``:
            a zero slice at index 0 followed by the replayed next states.
        """
        input_list = trajlist[:-1].reshape(-1, cfg.state_dim + cfg.index_dim)
        noise_tensor = noise_tensor.reshape(-1, 2)

        # The policy returns three values; only the first (the action) is used here.
        xyz_action, _, _ = self.policy(input_list, offset=noise_tensor)

        post_coord = env.step(input_list, xyz_action)
        post_coord = post_coord.reshape(cfg.ep_len, cfg.sk_num, cfg.state_dim)

        # Prepend a zero slice. It is allocated directly on post_coord's device,
        # so torch.cat never mixes devices and no host-to-device copy is needed.
        pre_il = torch.zeros((1, cfg.sk_num, cfg.state_dim), device=post_coord.device)
        return torch.cat((pre_il, post_coord), dim=0)

    def _collect_trajectories(self, cfg, env):
        """Collect one batch of rollouts and return ``(idx_set, trajlist, traj_for_train)``.

        Raises:
            ValueError: if ``cfg.d`` is neither 0 nor 1.
        """
        if cfg.d == 0:
            # Long episode: roll out without autograd, then rebuild the
            # trajectory with gradients via batch_forward.
            with torch.no_grad():
                idx_set, trajlist, noise_tensor = self.simulater.make_traj(self.policy, cfg.sk_num)
            traj_for_train = self.batch_forward(cfg, env, trajlist, noise_tensor)
        elif cfg.d == 1:
            # Short episode: train directly on the rollout, keeping only the
            # first two entries of its last axis.
            idx_set, trajlist, _ = self.simulater.make_traj(self.policy, cfg.sk_num)
            traj_for_train = trajlist[:, :, :2]
        else:
            raise ValueError(
                f"unsupported cfg.d {cfg.d!r}: expected 0 (long episode) or 1 (short episode)"
            )
        return idx_set, trajlist, traj_for_train

    def train(self, iteration):
        """Run ``iteration`` training iterations, then optionally save the policy.

        Each iteration does the following:
        1. Collects trajectories.
        2. Evaluates the policy under ``torch.no_grad``.
        3. Fits the controller on the trajectories.
        4. Scores them with the controller.
        5. Makes one ``policy_train`` call on the negated total reward.

        Args:
            iteration: number of iterations to run.

        Raises:
            ValueError: if no controller exists for ``cfg.algorithm``, or if
                ``cfg.d`` is neither 0 nor 1.
        """
        cfg = self.cfg
        env = self.env
        if self.dominator is None:
            raise ValueError(f"unsupported cfg.algorithm {cfg.algorithm!r}: no controller available")

        total_time = 0  # cumulative seconds spent inside policy_train
        for step in range(iteration):
            print(step)

            idx_set, trajlist, traj_for_train = self._collect_trajectories(cfg, env)

            with torch.no_grad():
                # NOTE(review): purpose of these short pauses around evaluation is
                # unconfirmed (possibly rendering/logging); kept as-is.
                time.sleep(0.05)
                self.evaluate_method.eval(self.policy, trajlist, step, total_time)
                time.sleep(0.05)

            self.dominator.feature_train(traj_for_train, idx_set)
            self.dominator.eval()
            reward = self.dominator(traj_for_train, idx_set)
            # Negated total reward, passed to policy_train as its loss
            # (presumably minimized there -- confirm in agent.Ddpg).
            neg_qvalue = torch.sum(-reward)

            start = time.perf_counter()
            self.policy.policy_train(neg_qvalue)
            total_time += time.perf_counter() - start

            # Release cached GPU memory and Python garbage between iterations.
            torch.cuda.empty_cache()
            gc.collect()

        # Push any pending TensorBoard events to disk.
        self.writer.flush()

        # Strict comparison kept on purpose: a non-bool value such as the
        # string "False" must not trigger a save.
        if cfg.save == True:  # noqa: E712
            import os
            os.makedirs('Parameter', exist_ok=True)
            state = self.policy.state_dict()
            torch.save(state, 'Parameter/' + cfg.algorithm + '.pth')
            torch.save(state, 'Parameter/' + cfg.algorithm + '_tmp.pth')


